[1/N] Elastic EP Milestone 2 #34861

Merged
tlrmchlsmth merged 40 commits into vllm-project:main from itayalroy:eep_m2_rebase on Feb 28, 2026
Conversation

@itayalroy (Contributor) commented Feb 19, 2026

This PR completes the work in #26278, originally authored by @libertyeagle, who designed and implemented the core architecture for elastic EP milestone 2. In collaboration with @tlrmchlsmth, we rebased the PR on top of the latest vLLM main and resolved all conflicts and remaining issues to get it ready for merge. All elastic EP features (scale up, scale down, and serving requests between scaling events) have been tested with multiple EP backends.

See Elastic EP RFC for more details: #20323

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀

@mergify mergify bot added the ci/build, nvidia, rocm (Related to AMD ROCm), and cpu (Related to CPU backends) labels Feb 19, 2026
@mergify mergify bot added the v1 label Feb 19, 2026
@github-project-automation github-project-automation bot moved this to Todo in AMD Feb 19, 2026

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a significant new feature, Elastic Expert Parallelism, with extensive changes across the codebase. The core logic for state management and distributed communication has been heavily modified. While the overall architecture seems well-thought-out, I've identified a few critical issues related to distributed coordination and configuration that could lead to deadlocks or incorrect behavior during scaling operations. Specifically, there are potential deadlocks in the scale-down logic, incorrect calculations for resource allocation, and assumptions about parallelism dimensions that may not hold true. These issues need to be addressed to ensure the stability and correctness of the new elastic EP feature.

Comment thread vllm/config/parallel.py
Comment on lines +338 to +395
```python
        elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(use_new_group=False):
                return False
            if self.old_dp_group.rank() == 0:
                self.old_dp_store.delete_key("eep_barrier_engine_count")
            self._eplb_reshuffle_before_scale_down()
            self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
            # NOTE(yongji): currently, after EPLB reshuffle
            # that redistributes experts to remaining workers, workers
            # to be removed will immediately initiate shutdown;
            # existing workers can no longer execute forward steps using
            # the old setup. In the future, we may keep
            # the removing workers alive a bit longer,
            # e.g., to drain in-batch requests.
            self._create_standby_groups()
            self._switch_and_prepare()
            self._update_parallel_config()
            self.state = ScaleDownRemainingEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleDownRemainingEngineState.COMPLETE
            return True

    def _progress_removing_engine(self) -> bool:
        state = self.state

        if state == ScaleDownRemovingEngineState.PREPARE:
            self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add("eep_barrier_engine_count", 1)
            return True

        if state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
            if (
                int(self.old_dp_store.get("eep_barrier_engine_count"))
                < self.old_dp_group.size()
            ):
                return False
            if not self._staged_barrier(use_new_group=False):
                return False
            assert self.old_dp_group.rank() > 0
            self._eplb_reshuffle_before_scale_down()
            self._destroy_old_comm_groups()
            self.state = ScaleDownRemovingEngineState.COMPLETE
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.SHUTDOWN_COMPLETE
            )
            self.engine_core.shutdown()
            return True

        else:
            assert self.state == ScaleDownRemovingEngineState.COMPLETE
            return True
```
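The `eep_barrier_engine_count` gate above is a store-counted rendezvous: each engine increments a shared counter when it enters the phase, and every engine polls (non-blocking, returning `False`) until the count reaches the group size. A minimal pure-Python sketch of that pattern, using a toy in-memory store in place of the real dp store (all names here are illustrative, not vLLM APIs):

```python
import threading

class ToyStore:
    """Illustrative stand-in for the dp store: atomic add/get/delete."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._data: dict[str, int] = {}

    def add(self, key: str, amount: int) -> int:
        with self._lock:
            self._data[key] = self._data.get(key, 0) + amount
            return self._data[key]

    def get(self, key: str) -> int:
        with self._lock:
            return self._data.get(key, 0)

    def delete_key(self, key: str) -> None:
        with self._lock:
            self._data.pop(key, None)

def barrier_reached(store: ToyStore, key: str, group_size: int) -> bool:
    """Non-blocking gate, like the EPLB_RESHUFFLE check above: an engine
    may proceed only once every member has incremented the counter."""
    return store.get(key) >= group_size

store = ToyStore()
store.add("eep_barrier_engine_count", 1)                # engine 0 announces
print(barrier_reached(store, "eep_barrier_engine_count", 2))  # False
store.add("eep_barrier_engine_count", 1)                # engine 1 announces
print(barrier_reached(store, "eep_barrier_engine_count", 2))  # True
store.delete_key("eep_barrier_engine_count")            # rank 0 resets the key
```

Returning `False` instead of blocking lets the engine's event loop keep making progress between polls, which is why both state machines re-enter the same state until the counter is full.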

critical

There is a potential deadlock condition during the scale-down process. After the collective call to _eplb_reshuffle_before_scale_down, the execution paths for remaining and removing engines diverge. The _progress_removing_engine calls _destroy_old_comm_groups(), while _progress_remaining_engine calls _create_standby_groups(). Both of these methods trigger collective RPCs.

Since the workers are in different states and initiating different collective operations, this will lead to a deadlock, as not all participants of the original group will be part of the same collective call. To fix this, all workers in the old distributed groups must participate in the same collective operations in the same order until the groups are destroyed or the removing workers are definitively excluded from future collectives.
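The failure mode described above can be reproduced in miniature with threads: model each collective as a barrier that only completes when every rank in the group joins it. A divergent schedule never completes (here a timeout stands in for the hang), while a shared, ordered schedule does. This is an illustrative sketch only, not vLLM code:

```python
import threading

class Collective:
    """Toy collective: completes only if every rank in the group calls it;
    a timeout stands in for the real deadlock."""

    def __init__(self, world_size: int) -> None:
        self._barrier = threading.Barrier(world_size)

    def wait(self, timeout: float = 0.5) -> bool:
        try:
            self._barrier.wait(timeout=timeout)
            return True
        except threading.BrokenBarrierError:
            return False

def run(world_size: int, plan) -> dict[int, bool]:
    """plan(rank) -> the ordered list of collectives that rank will join."""
    results: dict[int, bool] = {}

    def worker(rank: int) -> None:
        results[rank] = all(c.wait() for c in plan(rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results

# The bug in miniature: the remaining rank enters one collective while the
# removing rank enters a different one -- neither ever sees its peer arrive.
create_standby, destroy_old = Collective(2), Collective(2)
buggy = run(2, lambda r: [create_standby] if r == 0 else [destroy_old])
print(buggy)   # both ranks report False (a real run would hang, not time out)

# The fix in miniature: every member of the old group joins the same
# collectives in the same order before anyone is excluded.
step1, step2 = Collective(2), Collective(2)
fixed = run(2, lambda r: [step1, step2])
print(fixed)   # both ranks report True
```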

Comment thread vllm/distributed/parallel_state.py Outdated
Comment on lines +1829 to +1831
```python
            all_ranks = torch.arange(new_world_size_across_dp).reshape(
                -1, new_dp_size, pp_size, tp_size
            )
```

critical

The reshape operation for all_ranks in create_standby_groups does not account for prefill_context_model_parallel_size (pcp_size). It seems to assume pcp_size is always 1. If pcp_size > 1 is used with elastic EP, this will result in an incorrect tensor shape and subsequent errors when creating the standby process groups. The reshape operation should include prefill_context_model_parallel_size to correctly handle all parallelism dimensions.

Suggested change:

```python
all_ranks = torch.arange(new_world_size_across_dp).reshape(
    -1, new_dp_size, pp_size, tp_size
)
```

becomes

```python
all_ranks = torch.arange(new_world_size_across_dp).reshape(
    -1, new_dp_size, pp_size, self.prefill_context_model_parallel_size, tp_size
)
```
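To see why the missing axis matters even when the reshape happens to divide evenly (with `pcp_size > 1` it can fold silently into the leading `-1`), here is a small pure-Python check of the rank grid. It decomposes each rank into row-major coordinates, which is what the reshape encodes, and shows that the DP group membership differs. Names and sizes here are illustrative:

```python
dp, pp, pcp, tp = 2, 1, 2, 1          # pcp > 1 exposes the problem
world = dp * pp * pcp * tp            # 4 ranks

def dp_groups(dims: tuple[int, ...], dp_axis: int) -> list[list[int]]:
    """Row-major coordinates of each rank in `dims`; ranks that agree on
    every axis except the DP axis form one DP group (the same grouping a
    reshape-then-slice over all_ranks produces)."""
    groups: dict[tuple, list[int]] = {}
    for rank in range(world):
        coords, rem = [], rank
        for d in reversed(dims):
            coords.append(rem % d)
            rem //= d
        coords.reverse()
        key = tuple(c for i, c in enumerate(coords) if i != dp_axis)
        groups.setdefault(key, []).append(rank)
    return sorted(groups.values())

# Without the pcp axis, pcp folds into the leading -1, so the reshape
# succeeds -- but the DP groups it implies are wrong:
print(dp_groups((pcp, dp, pp, tp), dp_axis=1))        # [[0, 1], [2, 3]]
# With pcp as its own axis, the DP groups stride over pcp correctly:
print(dp_groups((1, dp, pp, pcp, tp), dp_axis=1))     # [[0, 2], [1, 3]]
```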

@mergify

mergify bot commented Feb 19, 2026

Hi @itayalroy, the pre-commit checks have failed. Please run:

```
uv pip install pre-commit
pre-commit install
pre-commit run --all-files
```

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
```
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint
```

@mergify mergify bot removed the needs-rebase label Feb 26, 2026
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) February 26, 2026 19:03
@mergify

mergify bot commented Feb 26, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @itayalroy.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Feb 26, 2026
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
@mergify mergify bot removed the needs-rebase label Feb 27, 2026
@tlrmchlsmth tlrmchlsmth merged commit dea2683 into vllm-project:main Feb 28, 2026
86 of 87 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 28, 2026
@github-project-automation github-project-automation bot moved this from Todo to Done in AMD Feb 28, 2026
EanWang211123 pushed a commit to EanWang211123/vllm that referenced this pull request Mar 2, 2026
Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>
Signed-off-by: EanWang211123 <wangyiheng@sangfor.com.cn>
wangxiyuan pushed a commit to vllm-project/vllm-ascend that referenced this pull request Mar 6, 2026
### What this PR does / why we need it?
Breaking changes:
- vllm-project/vllm#34102 
Disable_full param replaced with valid_modes/invalid_modes API
- vllm-project/vllm#35503
Now must return float compilation_time
- vllm-project/vllm#35564
New sequence_lengths param added
- vllm-project/vllm#33807
A check was performed (if runner_backend != "auto")
- vllm-project/vllm#34861
`BaseDeviceCommunicator` now accesses PyTorch's internal `pg_map` to
check process group state
- vllm-project/vllm#35274

**Important change:**
- vllm-project/vllm#28672

`matcher_utils` directly accesses `torch.ops._C.*` during the import
phase. In the Ascend environment, some unregistered ops trigger
`AttributeError`, causing e2e initialization failure.

https://github.com/vllm-project/vllm-ascend/actions/runs/22607260487/job/65502047131#step:10:2323

https://github.com/vllm-project/vllm/blob/main/vllm/compilation/passes/fusion/matcher_utils.py#L29

This PR adds temporary compatibility placeholders (rms_norm,
fused_add_rms_norm, rotate_embedding, static/dynamic fp8 quant,
silu_and_mul) to
`vllm_ascend/patch/platform/patch_fusion_matcher_compat_ops.py` to
ensure no crashes during the import phase. Upstream repairs will be
considered later.
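The placeholder approach can be sketched generically: patch an inert stub onto the op namespace for every op that was never registered on the platform, so that modules which merely *reference* the op at import time don't raise `AttributeError`, while actually *calling* a stub still fails loudly. This is an illustrative pattern with made-up names, not the actual `vllm_ascend` patch:

```python
class OpNamespace:
    """Toy stand-in for an op registry like torch.ops._C: attribute
    access fails for ops never registered on this platform."""

    def __init__(self, **registered):
        self.__dict__.update(registered)

def install_placeholders(ns, names):
    """Patch an inert stub for each missing op so import-time references
    succeed; calling a stub still raises, so failures stay visible."""
    for name in names:
        if hasattr(ns, name):
            continue

        def stub(*args, _name=name, **kwargs):
            raise NotImplementedError(f"op {_name!r} is unavailable on this platform")

        setattr(ns, name, stub)

ops = OpNamespace(silu_and_mul=lambda x: x)       # only one op registered
install_placeholders(ops, ["rms_norm", "fused_add_rms_norm", "silu_and_mul"])
print(callable(ops.rms_norm))                     # True: import-time reference ok
print(ops.silu_and_mul(3))                        # 3: real op left untouched
```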

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
vllm-project/vllm@15d76f7

---------

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: Meihan-chen <jcccx.cmh@gmail.com>
Co-authored-by: Claude Code <noreply@anthropic.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
ananyakgarg pushed a commit to ananyakgarg/vllm that referenced this pull request Mar 6, 2026
Summary:
vllm-project#34861 moved `init_device()` after `_init_message_queues()`, which breaks multi-node TP because `_init_message_queues` needs `_INNER_DP_WORLD`, which is set in `init_device()`. This swaps the order back.

vllm-project#35503 also added `max(compilation_times)`, but remote workers return None in multi-node; this filters them out.

Test Plan: OSS

Differential Revision: D95475427
LCAIZJ pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Mar 7, 2026
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
Signed-off-by: Yongji Wu <wuyongji317@gmail.com>
Signed-off-by: Itay Alroy <ialroy@nvidia.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Signed-off-by: Ron Tourgeman <rtourgeman@nvidia.com>
Co-authored-by: Yongji Wu <wuyongji317@gmail.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Ron Tourgeman <rtourgeman@nvidia.com>

Labels

- ci/build
- cpu (Related to CPU backends)
- frontend
- multi-modality (Related to multi-modality (#4194))
- nvidia
- ready (ONLY add when PR is ready to merge/full CI is needed)
- rocm (Related to AMD ROCm)
- v1

Projects

NVIDIA: Done
AMD: Done

Development

Successfully merging this pull request may close these issues.

4 participants